"""
JSON extraction and validation module for cognitive maps.

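This module provides functions to:
1. Extract JSON from text responses
2. Validate JSON structure for cognitive maps
3. Determine the format type (complex vs. simple)
4. Extract multiple-choice answer letters from model responses
5. Derive setting, model, and version metadata from item IDs and filenames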
"""

import json
import re
from typing import Dict, List, Tuple, Optional, Any, Union

def extract_json_from_text(text: str) -> Optional[Dict]:
    """
    Extract the JSON cognitive map embedded in a text response.

    The candidate span runs from the first ``{`` to the last ``}`` in *text*.
    Parsing is attempted in three stages, stopping at the first success:

    1. Strict ``json.loads`` on the whole span.
    2. Strict decode of one JSON object starting at the first ``{``,
       ignoring whatever follows it. This recovers a valid object when
       later text happens to contain braces, which would otherwise make
       the greedy span unparseable.
    3. ``clean_and_parse_json`` on the whole span (heuristic repair).

    Args:
        text: Text containing a JSON object

    Returns:
        Extracted JSON object or None
    """
    if not text:
        return None

    # Greedy on purpose: nested objects must stay inside the candidate span.
    span = re.search(r'\{[\s\S]*\}', text)
    if span is None:
        return None
    json_str = span.group(0)

    # Stage 1: the span is already valid JSON.
    try:
        return json.loads(json_str)
    except json.JSONDecodeError:
        pass

    # Stage 2: a valid object is followed by unrelated text containing '}'.
    try:
        parsed, _end = json.JSONDecoder().raw_decode(json_str)
        return parsed
    except json.JSONDecodeError:
        pass

    # Stage 3: try to clean up and parse again.
    return clean_and_parse_json(json_str)

def _repair_json_syntax(fragment: str) -> str:
    """Quote bare object keys and drop trailing commas in *fragment*."""
    # Fix unquoted keys.
    # NOTE: the pattern is not string-aware, so a `word:` sequence inside a
    # string value is rewritten as well.
    fragment = re.sub(r'(\s*?)(\w+)(\s*?):', r'\1"\2"\3:', fragment)
    # Fix trailing commas
    fragment = re.sub(r',\s*}', '}', fragment)
    return re.sub(r',\s*]', ']', fragment)


def clean_and_parse_json(json_str: str) -> Optional[Dict]:
    """
    Attempt to clean and parse a malformed JSON string.

    Two strategies are tried in order:

    1. Strip ``//`` comments, flatten newlines/tabs to spaces, quote bare
       keys, remove trailing commas, then parse the whole string.
    2. Scan the *original* string for ``"key": { ... }`` pairs (at most one
       level of nested braces inside the value), repair and parse each
       value on its own, and collect the pairs that parse.

    Only parse-related failures are suppressed; interrupts and other
    unrelated exceptions propagate to the caller.

    Args:
        json_str: A potentially malformed JSON string

    Returns:
        Parsed JSON object or None
    """
    # json.JSONDecodeError is a ValueError; TypeError covers non-string
    # input; RecursionError covers pathologically deep nesting.
    recoverable = (ValueError, TypeError, RecursionError)

    try:
        # Remove comments.
        # NOTE: not string-aware, so `//` inside a string value is cut too.
        clean_json = re.sub(r'//.*', '', json_str)
        # Remove newlines, tabs
        clean_json = re.sub(r'[\n\r\t]', ' ', clean_json)
        return json.loads(_repair_json_syntax(clean_json))
    except recoverable:
        pass

    # As a final attempt, try to extract in "key-value" format:
    # pairs like "object_name": { "position": [...], "facing": ... }
    pairs_pattern = r'"([^"]+)":\s*{([^{}]*(?:{[^{}]*}[^{}]*)*)}'
    try:
        pairs = re.findall(pairs_pattern, json_str)
    except recoverable:
        return None

    result = {}
    for key, value in pairs:
        try:
            result[key] = json.loads(_repair_json_syntax('{' + value + '}'))
        except recoverable:
            # Skip values that stay unparseable; keep whatever else parses.
            continue

    return result or None

def is_complex_format(cogmap: Dict) -> bool:
    """
    Tell whether a cognitive map uses the complex layout.

    A map counts as complex when it is a dictionary whose "objects" entry
    is a list. Anything else (including non-dictionaries) is treated as
    the simple key-value layout.

    Args:
        cogmap: The cognitive map JSON

    Returns:
        True if complex format, False if simple format
    """
    if isinstance(cogmap, dict):
        # A missing "objects" key yields None, which is not a list.
        return isinstance(cogmap.get("objects"), list)
    return False

def is_valid_position(position: Any) -> bool:
    """
    Check whether a position value is usable as 2-D coordinates.

    The value must be a list with at least two elements, and each of the
    first two elements must be convertible with ``float()``. Extra
    elements are ignored.

    Args:
        position: The position value to check

    Returns:
        True if valid, False otherwise
    """
    if not isinstance(position, list) or len(position) < 2:
        return False

    for coordinate in position[:2]:
        try:
            float(coordinate)
        except (ValueError, TypeError):
            return False
    return True

def truncate_position_list_into_one(positions: Union[List[Dict], Dict]) -> Dict:
    """
    Collapse a list of position entries into its first entry.

    Args:
        positions: Either a single entry or a list of entries

    Returns:
        The first element when *positions* is a non-empty list, an empty
        dictionary when it is an empty list, and *positions* itself
        (unchanged) for any other value.
    """
    if isinstance(positions, list):
        # An empty list has no first entry; return an empty entry instead
        # of raising IndexError.
        return positions[0] if positions else {}
    return positions

def trucate_object_position(raw_cogmap: Dict) -> Dict:
    """
    Collapse every list-valued entry of a cognitive map to a single entry.

    Each value is passed through ``truncate_position_list_into_one``; keys
    are kept as they are. A non-dictionary input yields an empty dictionary.
    """
    if isinstance(raw_cogmap, dict):
        collapsed = {}
        for name, entry in raw_cogmap.items():
            collapsed[name] = truncate_position_list_into_one(entry)
        return collapsed

    return {}

def is_valid_facing(facing: Any) -> bool:
    """
    Check whether a facing value names a recognised direction.

    ``None`` and an empty list are accepted because facing is optional.
    For a non-empty list only the first element is examined. Strings are
    compared case-insensitively after trimming surrounding whitespace.
    Accepted words: up, down, left, right, inner, outer, top, bottom,
    north, south, east, west, front, back, into, out, inside, outside,
    forward, backward.

    Args:
        facing: The facing value to check

    Returns:
        True if valid, False otherwise
    """
    if facing is None:
        return True  # Facing is optional

    candidate = facing
    if isinstance(candidate, list):
        if len(candidate) == 0:
            return True
        candidate = candidate[0]

    if isinstance(candidate, str):
        normalized = candidate.lower().strip()
        return normalized in {
            "up", "down", "left", "right", "inner", "outer",
            "top", "bottom", "north", "south", "east", "west",
            "front", "back", "into", "out", "inside", "outside",
            "forward", "backward",
        }

    return False

def _validate_cogmap_entries(entries: List[Any], label: str, errors: List[str]) -> None:
    """Append to *errors* each problem found in a complex-format entry list.

    *label* ("Object" or "View") prefixes every message together with the
    entry's index.
    """
    for i, entry in enumerate(entries):
        if not isinstance(entry, dict):
            errors.append(f"{label} {i} is not a dictionary")
            continue

        if "name" not in entry:
            errors.append(f"{label} {i} is missing 'name' field")

        if "position" in entry and not is_valid_position(entry["position"]):
            errors.append(f"{label} {i} has invalid 'position' format")

        if "facing" in entry and not is_valid_facing(entry["facing"]):
            errors.append(f"{label} {i} has invalid 'facing' value")


def validate_cogmap_format(cogmap: Dict) -> Tuple[bool, List[str]]:
    """
    Validate if a cognitive map has the correct format.

    Complex format (``is_complex_format`` is true): every element of
    "objects" and of the optional "views" list must be a dictionary with a
    "name"; "position" and "facing" are checked only when present. A
    "views" value that is present but not a list is reported as an error
    rather than iterated.

    Simple format: every top-level value must be a dictionary that
    contains only "position" and/or "facing", each checked when present.

    Args:
        cogmap: The cognitive map to validate

    Returns:
        Tuple of (is_valid, error_messages)
    """
    if not isinstance(cogmap, dict):
        return False, ["Cognitive map is not a dictionary"]

    errors: List[str] = []

    if is_complex_format(cogmap):
        # is_complex_format already guarantees that "objects" is a list.
        _validate_cogmap_entries(cogmap["objects"], "Object", errors)

        # Check views (if present)
        views = cogmap.get("views", [])
        if isinstance(views, list):
            _validate_cogmap_entries(views, "View", errors)
        else:
            # e.g. "views": null would otherwise crash enumerate().
            errors.append("'views' field is not a list")
    else:
        # Validate simple format (key-value)
        for obj_name, obj_data in cogmap.items():
            if not isinstance(obj_data, dict):
                errors.append(f"Object '{obj_name}' is not a dictionary")
                continue

            if "position" in obj_data and not is_valid_position(obj_data["position"]):
                errors.append(f"Object '{obj_name}' has invalid 'position' format")

            if "facing" in obj_data and not is_valid_facing(obj_data["facing"]):
                errors.append(f"Object '{obj_name}' has invalid 'facing' value")

            # Check for unknown fields
            unknown_fields = [f for f in obj_data if f not in ("position", "facing")]
            if unknown_fields:
                errors.append(
                    f"Object '{obj_name}' has unknown fields: {', '.join(map(str, unknown_fields))}"
                )

    return not errors, errors

def determine_format_from_filename(filename: str) -> str:
    """
    Infer the expected cognitive-map format from a filename.

    The check is a case-insensitive substring test for "shorten".

    Args:
        filename: The filename to check

    Returns:
        'simple' if filename contains 'shorten', 'complex' otherwise
    """
    if "shorten" in filename.lower():
        return "simple"
    return "complex"

def extract_answer(text: str) -> Optional[str]:
    """
    Pull the chosen option letter (A-E) out of a model response.

    A fixed list of regular expressions is tried from most to least
    specific. The first expression that matches anywhere in *text* decides
    the result, and the letter from its last match is returned. If none
    match, the text is scanned line by line: first for a letter followed by
    a description, then, from the bottom line upwards, for any stand-alone
    letter.

    Args:
        text: The model response text

    Returns:
        The last answer letter found by the highest priority matching pattern,
        or None if not found
    """
    if not text:
        return None

    # Ordered from highest to lowest priority.
    prioritized = (
        # Full answer with description
        r'(?:Answer: )?([A-E])\. [A-Za-z0-9 \-\(\)\'",]+(?=(?:\n|$|\.|"))',
        # Answer with partial description
        r'(?:Answer: )?([A-E])\. [A-Za-z0-9 \-\(\)\'"]+',
        # Answer with punctuation or whitespace
        r'(?:Answer: )?([A-E])(?:\.|$|\s)',
        # Answer at line beginning
        r'(?:^|\n)(?:Answer: )?([A-E])(?:\.|$|\s)',
        # Answer in quotes or asterisks
        r'[\*\"]([A-E])[\*\"]',
        # Answer following "Answer:"
        r'\bAnswer:?\s*([A-E])\b',
    )

    for regex in prioritized:
        # With a single capture group, findall yields the captured letters.
        letters = re.findall(regex, text)
        if letters:
            return letters[-1]

    rows = text.split('\n')

    # Line-by-line pass: letter followed by a description; last line wins.
    described = [
        hit.group(1)
        for hit in (re.search(r'([A-E])\. [A-Za-z0-9 \-\(\)\'",]+', row) for row in rows)
        if hit
    ]
    if described:
        return described[-1]

    # Most general pass: first stand-alone letter on the lowest matching line.
    for row in reversed(rows):
        hit = re.search(r'\b([A-E])\b', row)
        if hit:
            return hit.group(1)

    return None  # No answer found

def get_setting_from_id(item_id: str) -> str:
    """
    Map an item ID to its setting category.

    The ID is lower-cased and searched for the category names in a fixed
    order (around, rotation, translation, among); the first name found
    wins.

    Args:
        item_id: The item identifier string

    Returns:
        Setting category ('around', 'rotation', 'translation', 'among', or 'other')
    """
    lowered = item_id.lower()
    for category in ('around', 'rotation', 'translation', 'among'):
        if category in lowered:
            return category
    return 'other'

def extract_model_and_version(filename: str) -> Tuple[str, str, bool]:
    """
    Read model name, version tag and cogmap-generation flag from a filename.

    Only the base name of *filename* is inspected. The model name is the
    label of the first known marker found in it ("unknown" if none). The
    version is "v<digits>" from the first ``v`` followed by digits, and
    "v1" when there is no such sequence.

    Args:
        filename: The filename to extract information from

    Returns:
        Tuple of (model_name, version, gen_cogmap)
    """
    import os
    import re

    base = os.path.basename(filename)

    # (marker searched in the base name, model name reported), in priority order.
    known_models = (
        ("gpt4o_mini", "gpt4o_mini"),
        ("gpt41nano", "gpt41_nano"),
        ("qwen_25_vl_7b", "qwen_25_vl_7b"),
        ("qwen_25_vl_3b", "qwen_25_vl_3b"),
    )
    model_name = next(
        (label for marker, label in known_models if marker in base),
        "unknown",
    )

    # Extract version (v2, v3, etc.). Without a "v<digit>" sequence the
    # name cannot contain "v2".."v9" either, so the default is always v1.
    found = re.search(r'v(\d+)', base)
    version = f"v{found.group(1)}" if found else "v1"

    # Check additional flags
    gen_cogmap = any(tag in base for tag in ("gen_cogmap", "cogmap_gen"))

    return model_name, version, gen_cogmap

def determine_answer_field(eval_type: str) -> str:
    """
    Name the item field that holds the answer for an evaluation type.

    Args:
        eval_type: The type of evaluation (e.g., 'cogmap_gen', 'cogmap_input', 'plain')
    Returns:
        str: 'cogmap_gen_answer' for 'cogmap_gen', 'cogmap_answer' for
        'cogmap_input', and 'plain_answer' for anything else
    """
    field_by_eval_type = (
        ('cogmap_gen', 'cogmap_gen_answer'),
        ('cogmap_input', "cogmap_answer"),
    )
    for known_type, field_name in field_by_eval_type:
        if eval_type == known_type:
            return field_name
    # Every other evaluation type reads the plain answer.
    return "plain_answer"



# ==============================
#  ---- Test the functions ----

def test_extract_json_from_text():
    """Print the JSON object extracted from a sample model response.

    The sample is a cognitive map followed by free-text reasoning and a
    final answer. It is written as adjacent string literals, one per line
    of the response, so that no physical line break falls inside a literal.
    """
    print("Testing extract_json_from_text")

    raw_text = (
        '<CogMap>\n'
        '{\n'
        '    "card_catalog_cabinet": {"position": [1, 4]},\n'
        '    "bookcase": {"position": [4, 2]},\n'
        '    "long_table": {"position": [5, 6]},\n'
        '    "several_bookcases": [{"position": [3, 2]}, {"position": [7, 2]}, {"position": [8, 2]}]\n'
        '}\n'
        '<Reasoning>\n'
        "1. Analyzing Image 4 (the starting viewpoint):  \n"
        "   - The image shows large card catalog cabinets along the right side of the scene, likely near the top of the grid due to the perspective.\n"
        "   - The long table, with chairs, is positioned toward the bottom part of the scene.\n"
        "   - The surrounding environment is a small room with multiple wooden bookshelves lining the walls.\n\n"
        "2. Analyzing Image 1 (rotated 90° from Image 4):  \n"
        "   - This shows the same scene with the cabinets now on the top side of the view.\n"
        "   - The long table and chairs are visible at the bottom, confirming their placement in the scene.\n"
        "   - Multiple bookcases are observed along the left side, with some indicated as several book cases in the scene.\n\n"
        "3. Analyzing Image 2 (another 90° rotation):  \n"
        "   - It indicates a corner view where the bookcases are prominent along one wall close to the top.\n"
        "   - The cabinets and table are consistent with previous placements.\n\n"
        "4. Analyzing Image 3 (final 90° turn):  \n"
        "   - Provides the opposite wall view, where bookshelves are again visible, confirming placement.\n"
        "   - The room layout is consistent with the previous observations.\n\n"
        "5. Spatial relationships:  \n"
        "   - The card catalog cabinet, based on the first image, is at the top right, so roughly around grid position [1, 4].\n"
        "   - Multiple bookcases are along the top part of the room, roughly spread around [3, 2], [7, 2], and [8, 2].\n"
        "   - The long table is toward the bottom of the scene, around [5, 6].\n\n"
        "6. Orientation and relative positions:  \n"
        "   - When standing at image 4's viewpoint, turning left 90° would face toward the height of the room, bringing the shelf area into view.\n"
        "   - The shelf area strongly resembles the several bookcases, which are on the wall to the left when facing from that position, roughly at [3, 2].\n\n"
        "7. Conclusion:  \n"
        "   - The object to the immediate left after turning 90° from the viewpoint in image 4 is **several bookcases**.\n"
        "   \n"
        "8. The options provided:  \n"
        "   A. Long table  \n"
        "   B. Bookcase  \n"
        "   C. Several bookcases  \n"
        "   D. Card catalog cabinet  \n\n"
        "   The correct answer based on the spatial reasoning is **C. Several bookcases**.\n"
        "   \n"
        "Therefore, my answer is <Answer> C. Several bookcases"
    )

    json_obj = extract_json_from_text(raw_text)
    print(json_obj)


def test_is_complex_format():
    """Print is_complex_format's verdict for a complex map, then a simple one."""
    print("Testing is_complex_format")

    complex_map = {
        "objects": [
            {"name": "object1", "position": [1, 2], "facing": "up"},
            {"name": "object2", "position": [3, 4], "facing": "down"}
        ],
        "views": [
            {"name": "view1", "position": [1, 2], "facing": "up"},
            {"name": "view2", "position": [3, 4], "facing": "down"}
        ]
    }
    simple_map = {
        "card_catalog_cabinet": {"position": [1, 4]},
        "bookcase": {"position": [4, 2]},
        "long_table": {"position": [5, 6]},
        "several_bookcases": [{"position": [3, 2]}, {"position": [7, 2]}, {"position": [8, 2]}]
    }

    for sample in (complex_map, simple_map):
        print(is_complex_format(sample))

def test_is_valid_facing():
    """Print is_valid_facing's verdict for each accepted word, None, and a bogus word."""
    print("Testing is_valid_facing")

    samples = [
        'up', 'down', 'left', 'right', 'inner', 'outer',
        'top', 'bottom', 'north', 'south', 'east', 'west',
        'front', 'back', 'into', 'out', 'inside', 'outside',
        'forward', 'backward',
        None,
        'fasdf',
    ]
    for sample in samples:
        # repr() renders strings with single quotes and None bare, which
        # reproduces the call text shown in the output.
        print(f"is_valid_facing({sample!r}): {is_valid_facing(sample)}")
    
def test_validate_cogmap_format():
    """Print validation results for a complex map, a simple map, and the
    simple map after list-valued entries are collapsed."""
    print("Testing validate_cogmap_format")

    # Complex format sample
    complex_map = {
        "objects": [
            {"name": "object1", "position": [1, 2], "facing": "up"},
            {"name": "object2", "position": [3, 4], "facing": "down"}
        ],
        "views": [
            {"name": "view1", "position": [1, 2], "facing": "up"},
            {"name": "view2", "position": [3, 4], "facing": "down"}
        ]
    }
    # Simple format sample; "several_bookcases" is deliberately a list
    simple_map = {
        "object1": {"position": [1, 2], "facing": "up"},
        "object2": {"position": [3, 4], "facing": "down"},
        "object3": {"position": [5, 6]},
        "several_bookcases": [{"position": [3, 2]}, {"position": [7, 2]}, {"position": [8, 2]}]
    }

    print(validate_cogmap_format(complex_map))
    print(validate_cogmap_format(simple_map))

    # Same simple map after truncating list entries to their first element
    collapsed_map = trucate_object_position(simple_map)
    print(validate_cogmap_format(collapsed_map))

def test_is_valid_position():
    """Print is_valid_position's verdict for a pair, a triple, and a string."""
    print("Testing is_valid_position")
    for sample in ([1, 2], [1, 2, 3], "not a list"):
        print(is_valid_position(sample))

def test_determine_format_from_filename():
    """Print the format inferred for a 'shorten' filename and a plain one."""
    print("Testing determine_format_from_filename")
    for sample_name in ("shorten_cogmap.json", "cogmap.json"):
        print(determine_format_from_filename(sample_name))

def test_answer_extractions():
    """Print sample results from the answer, setting, answer-field and
    filename helpers."""
    # Test extract_answer
    test_text = "Answer: C. School bag and TV cabinet D. "
    print(f"Extracted answer: {extract_answer(test_text)}")

    # Test get_setting_from_id
    test_id = "around_rotation_translation_among_1"
    print(f"Setting from ID: {get_setting_from_id(test_id)}")

    # Test determine_answer_field (it takes only the evaluation type, so no
    # sample item is needed)
    print(f"Determined answer field: {determine_answer_field('cogmap_gen')}")

    # Test extract_model_and_version
    test_filename = "gpt4o_mini_v2_gen_cogmap.jsonl"
    print(f"Extracted model and version: {extract_model_and_version(test_filename)}")

if __name__ == "__main__":
    # Run every demo, in the same order as before.
    for run_demo in (
        test_extract_json_from_text,
        test_is_complex_format,
        test_is_valid_position,
        test_is_valid_facing,
        test_validate_cogmap_format,
        test_determine_format_from_filename,
        test_answer_extractions,
    ):
        run_demo()